SAX circuit simulator#

SAX is a circuit solver written in JAX. Writing your component models in SAX lets you compute not only the function values but also their gradients, which is useful for circuit optimization.

This tutorial has been adapted from the SAX Quick Start.

You can install sax with pip (read the SAX install instructions here):

! pip install sax
[1]:
from tqdm import trange
from tqdm.notebook import tqdm, trange

from numpy.fft import fft2, fftfreq, fftshift, ifft2
from typing import List
from functools import partial
import sys
import itertools
from pprint import pprint
from omegaconf import OmegaConf
import numpy as np
from sklearn.linear_model import LinearRegression
from scipy import constants
import jax.example_libraries.optimizers as opt
import jax.numpy as jnp
import jax

import gdsfactory.simulation.gtidy3d as gt
from gdsfactory.simulation.get_sparameters_path import get_sparameters_path_lumerical
import gdsfactory as gf
import gdsfactory.simulation.sax as gs
from gdsfactory.get_netlist import get_netlist as _get_netlist

import matplotlib.pyplot as plt
import sax

import sys
import logging
from rich.logging import RichHandler
from gdsfactory.generic_tech import get_generic_pdk

gf.config.rich_output()
PDK = get_generic_pdk()
PDK.activate()

logger = logging.getLogger()
# Route root-logger output through a single RichHandler at WARNING level.
# `force=True` (Python 3.8+) first removes any handlers already attached to the
# root logger; `Logger.removeHandler` only accepts Handler objects, so passing a
# stream such as `sys.stderr` to it would not remove anything.
logging.basicConfig(
    level="WARNING", datefmt="[%X]", handlers=[RichHandler()], force=True
)

gf.config.set_plot_options(show_subports=False)
2023-02-20 17:56:37.114 | INFO     | gdsfactory.config:<module>:50 - Load '/home/runner/work/gdsfactory/gdsfactory/gdsfactory' 6.43.1
[17:56:38] INFO     Using client version: 1.8.4                                                     __init__.py:120
2023-02-20 17:56:38.980 | INFO     | gdsfactory.simulation.gtidy3d:<module>:54 - Tidy3d '1.8.4' installed at ['/usr/share/miniconda/envs/anaconda-client-env/lib/python3.9/site-packages/tidy3d']
2023-02-20 17:56:39.349 | INFO     | gdsfactory.technology.layer_views:__init__:785 - Importing LayerViews from YAML file: /home/runner/work/gdsfactory/gdsfactory/gdsfactory/generic_tech/layer_views.yaml.
2023-02-20 17:56:39.356 | INFO     | gdsfactory.pdk:activate:206 - 'generic' PDK is now active

Scatter dictionaries#

The core data structure for specifying scatter parameters in SAX is a dictionary… more specifically, a dictionary which maps a port combination (2-tuple) to a scatter parameter (or an array of scatter parameters when considering multiple wavelengths, for example). Such a dictionary mapping is called an SDict in SAX (SDict ≈ Dict[Tuple[str,str], float]).

Dictionaries are in fact much better suited for characterizing S-parameters than, say, (jax-)numpy arrays due to the inherent sparse nature of scatter parameters. Moreover, dictionaries allow for string indexing, which makes them much more pleasant to use in this context.

o2            o3
   \        /
    ========
   /        \
o1            o4
[2]:
coupling = 0.5  # fraction of power transferred to the cross port
kappa = coupling**0.5  # cross-coupling field amplitude
tau = (1 - coupling) ** 0.5  # through field amplitude

# Every S-parameter is written out explicitly, once per direction.
coupler_dict = {
    ("o1", "o4"): tau, ("o4", "o1"): tau,
    ("o1", "o3"): 1j * kappa, ("o3", "o1"): 1j * kappa,
    ("o2", "o4"): 1j * kappa, ("o4", "o2"): 1j * kappa,
    ("o2", "o3"): tau, ("o3", "o2"): tau,
}
coupler_dict
{
    ('o1', 'o4'): 0.7071067811865476,
    ('o4', 'o1'): 0.7071067811865476,
    ('o1', 'o3'): 0.7071067811865476j,
    ('o3', 'o1'): 0.7071067811865476j,
    ('o2', 'o4'): 0.7071067811865476j,
    ('o4', 'o2'): 0.7071067811865476j,
    ('o2', 'o3'): 0.7071067811865476,
    ('o3', 'o2'): 0.7071067811865476
}

It can still be tedious to specify every port in the circuit manually. SAX therefore offers the `reciprocal` function, which auto-fills the reverse connection if the forward connection exists. For example:

[3]:
# Only the forward paths are listed; sax.reciprocal fills in the reverse ones.
forward_paths = {
    ("o1", "o4"): tau,
    ("o1", "o3"): 1j * kappa,
    ("o2", "o4"): 1j * kappa,
    ("o2", "o3"): tau,
}
coupler_dict = sax.reciprocal(forward_paths)

coupler_dict
{
    ('o1', 'o4'): 0.7071067811865476,
    ('o1', 'o3'): 0.7071067811865476j,
    ('o2', 'o4'): 0.7071067811865476j,
    ('o2', 'o3'): 0.7071067811865476,
    ('o4', 'o1'): 0.7071067811865476,
    ('o3', 'o1'): 0.7071067811865476j,
    ('o4', 'o2'): 0.7071067811865476j,
    ('o3', 'o2'): 0.7071067811865476
}

Parametrized Models#

Constructing such an SDict is easy, however, usually we’re more interested in having parametrized models for our components. To parametrize the coupler SDict, just wrap it in a function to obtain a SAX Model, which is a keyword-only function mapping to an SDict:

[4]:
def coupler(coupling=0.5) -> sax.SDict:
    """Ideal lossless directional coupler.

    Args:
        coupling: fraction of power sent to the cross port (expected 0..1).

    Returns:
        Reciprocal SDict: through paths carry sqrt(1 - coupling), cross paths
        carry 1j * sqrt(coupling).
    """
    through = (1 - coupling) ** 0.5
    cross = 1j * coupling**0.5
    forward = {
        ("o1", "o4"): through,
        ("o1", "o3"): cross,
        ("o2", "o4"): cross,
        ("o2", "o3"): through,
    }
    return sax.reciprocal(forward)


coupler(coupling=0.3)
{
    ('o1', 'o4'): 0.8366600265340756,
    ('o1', 'o3'): 0.5477225575051661j,
    ('o2', 'o4'): 0.5477225575051661j,
    ('o2', 'o3'): 0.8366600265340756,
    ('o4', 'o1'): 0.8366600265340756,
    ('o3', 'o1'): 0.5477225575051661j,
    ('o4', 'o2'): 0.5477225575051661j,
    ('o3', 'o2'): 0.8366600265340756
}
[5]:
def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0) -> sax.SDict:
    """Dispersive straight waveguide with first-order effective-index model.

    Args:
        wl: wavelength (same length unit as `length` and `wl0`).
        wl0: reference wavelength at which `neff` and `ng` are given.
        neff: effective index at `wl0`.
        ng: group index at `wl0`.
        length: waveguide length.
        loss: propagation loss in dB per unit length.

    Returns:
        Reciprocal SDict with the o1<->o2 complex transmission.
    """
    # Linear dispersion derived from ng: neff(wl) = neff - (wl - wl0) * (ng - neff) / wl0
    slope = (ng - neff) / wl0
    neff_at_wl = neff - (wl - wl0) * slope
    phase = 2 * jnp.pi * neff_at_wl * length / wl
    amplitude = 10 ** (-loss * length / 20)  # dB -> field amplitude
    return sax.reciprocal({("o1", "o2"): amplitude * jnp.exp(1j * phase)})

Waveguide model#

You can create a dispersive waveguide model in SAX.

Let's compute the effective index neff and group index ng of a 500 nm wide strip waveguide at a wavelength of 1550 nm.

[6]:
# 500 nm x 220 nm silicon strip (no slab) clad in oxide, at λ = 1.55 μm.
strip = gt.modes.Waveguide(
    ncore="si",
    nclad="sio2",
    wavelength=1.55,
    wg_width=0.5,
    wg_thickness=0.22,
    slab_thickness=0.0,
)
strip.plot_Ex(0)  # TE
2023-02-20 17:56:39.427 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/4a9b4373.npz mode data from file cache.
../../../_images/notebooks_plugins_sax_01_sax_11_1.png
[7]:
# Real part of the first computed mode's effective index.
neff = np.real(strip.neffs[0])
neff
2.4658897
[8]:
nm = 1e-3  # one nanometre expressed in μm
strip_geometry = dict(
    wg_width=500 * nm,
    wg_thickness=220 * nm,
    slab_thickness=0 * nm,
    ncore="si",
    nclad="sio2",
)
ng = gt.modes.group_index(wavelength=1.55, **strip_geometry)
ng
2023-02-20 17:56:39.768 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/4a9b4373.npz mode data from file cache.
2023-02-20 17:56:39.790 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/ba607431.npz mode data from file cache.
2023-02-20 17:56:39.812 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/a7a6b1cb.npz mode data from file cache.
4.1703596115112305
[9]:
# Bind the mode-solver neff and group index ng computed above into the generic
# gdsfactory straight-waveguide SAX model.
straight_sc = gf.partial(gs.models.straight, neff=neff, ng=ng)
[10]:
# Plot the straight-waveguide model and clamp the y-axis to [-1, 1].
gs.plot_model(straight_sc)
plt.ylim(-1, 1)
[17:56:39] INFO     Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver   xla_bridge.py:421
                    in registry given worker:                                                                      
           INFO     Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no     xla_bridge.py:421
                    attribute 'GpuAllocatorConfig'                                                                 
           INFO     Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no     xla_bridge.py:421
                    attribute 'GpuAllocatorConfig'                                                                 
           INFO     Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not      xla_bridge.py:421
                    available.                                                                                     
           INFO     Unable to initialize backend 'plugin': xla_extension has no attributes named  xla_bridge.py:421
                    get_plugin_device_client. Compile TensorFlow with                                              
                    //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults                    
                    to false) to enable this.                                                                      
           WARNING  No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun  xla_bridge.py:428
                    for more info.)                                                                                
(-1.0, 1.0)
../../../_images/notebooks_plugins_sax_01_sax_15_7.png
[11]:
# Same model with phase=True, presumably plotting the phase instead of the magnitude — confirm in plot_model docs.
gs.plot_model(straight_sc, phase=True)
../../../_images/notebooks_plugins_sax_01_sax_16_0.png

Coupler model#

[12]:
# Directional coupler layout: 10 μm coupling length, 0.2 μm (200 nm) gap.
c = gf.components.coupler(length=10, gap=0.2)
c
coupler_gap0p2_length10: uid 611fd498, ports ['o1', 'o2', 'o3', 'o4'], references [], 6 polygons
[13]:
nm = 1e-3  # one nanometre expressed in μm
# Two identical 500 nm strips separated by a 200 nm gap.
cp = gt.modes.WaveguideCoupler(
    ncore="si",
    nclad="sio2",
    wavelength=1.55,
    wg_width1=500 * nm,
    wg_width2=500 * nm,
    gap=200 * nm,
    wg_thickness=220 * nm,
    slab_thickness=0 * nm,
)
for supermode in (0, 1):  # 0: even mode, 1: odd mode
    cp.plot_Ex(supermode, plot_power=False)
2023-02-20 17:56:41.021 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/f359d64b.npz mode data from file cache.
../../../_images/notebooks_plugins_sax_01_sax_19_1.png
../../../_images/notebooks_plugins_sax_01_sax_19_2.png
[14]:
# Show the signature and arguments of find_coupling_vs_gap (used in the next cell).
help(gt.modes.find_coupling_vs_gap)
Help on cython_function_or_method in module gdsfactory.simulation.gtidy3d.modes:

find_coupling_vs_gap(gap1: 'float' = 0.2, gap2: 'float' = 0.4, steps: 'int' = 12, nmodes: 'int' = 4, wavelength: 'float' = 1.55, **kwargs) -> 'pd.DataFrame'
    Returns coupling vs gap pandas DataFrame.

    Args:
        gap1: starting gap in um.
        gap2: end gap in um.
        steps: number of steps.
        nmodes: number of modes.
        wavelength: wavelength (um).

    Keyword Args:
        wg_width: waveguide width.
        wg_width1: optional left waveguide width in um.
        wg_width2: optional right waveguide width in um.
        wg_thickness: thickness waveguide (um).
        ncore: core refractive index.
        nclad: cladding refractive index.
        slab_thickness: thickness slab (um).
        t_box: thickness BOX (um).
        t_clad: thickness cladding (um).
        xmargin: margin from waveguide edge to each side (um).
        resolution: pixels/um. Can be a single number or tuple (x, y).
        bend_radius: optional bend radius (um).
        cache: True uses file cache from PDK.modes_path. False skips cache.

[15]:
# Sweep the gap with the function defaults (gap1=0.2 .. gap2=0.4 μm, 12 steps)
# for two 500 nm x 220 nm silicon strips in oxide.
df = gt.modes.find_coupling_vs_gap(
    ncore="si",
    nclad="sio2",
    wg_thickness=220 * nm,
    slab_thickness=0 * nm,
    wg_width1=500 * nm,
    wg_width2=500 * nm,
)
df
2023-02-20 17:56:41.646 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/f359d64b.npz mode data from file cache.
2023-02-20 17:56:41.666 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/fcb86125.npz mode data from file cache.
2023-02-20 17:56:41.686 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/aaf07c4a.npz mode data from file cache.
2023-02-20 17:56:41.707 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/31b76c44.npz mode data from file cache.
2023-02-20 17:56:41.728 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/2f4e3e77.npz mode data from file cache.
2023-02-20 17:56:41.749 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/33e70421.npz mode data from file cache.
2023-02-20 17:56:41.770 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/fad263ef.npz mode data from file cache.
2023-02-20 17:56:41.790 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/33b49974.npz mode data from file cache.
2023-02-20 17:56:41.810 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/e4bc7a19.npz mode data from file cache.
2023-02-20 17:56:41.830 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/d1c02ec7.npz mode data from file cache.
2023-02-20 17:56:41.852 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/708b93c3.npz mode data from file cache.
2023-02-20 17:56:41.872 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/932ed1c4.npz mode data from file cache.
[15]:
gap ne no lc dn
0 0.200000 2.430307 2.403961 29.415999 0.026346
1 0.218182 2.434447 2.408951 30.397202 0.025496
2 0.236364 2.460559 2.448616 64.893606 0.011943
3 0.254545 2.464516 2.452997 67.281800 0.011519
4 0.272727 2.430179 2.416353 56.053277 0.013826
5 0.290909 2.444588 2.432067 61.894695 0.012521
6 0.309091 2.435488 2.428164 105.820223 0.007324
7 0.327273 2.439679 2.432644 110.152003 0.007036
8 0.345455 2.443704 2.437016 115.889536 0.006687
9 0.363636 2.447752 2.441331 120.682591 0.006422
10 0.381818 2.439935 2.436173 206.007073 0.003762
11 0.400000 2.454328 2.450968 230.701604 0.003359

For a 200 nm gap the effective index difference dn is 0.026, which means that 100% of the power is coupled over a coupling length lc of about 29.4 μm.

[16]:
# Coupler model with dn from the 200 nm gap row of `df`, zero coupling length
# and no additional bend coupling (coupling0=0).
coupler_sc = gf.partial(gs.models.coupler, dn=0.026, length=0, coupling0=0)
gs.plot_model(coupler_sc)
../../../_images/notebooks_plugins_sax_01_sax_23_0.png

If we ignore the coupling from the bends (coupling0 = 0), then for 3 dB coupling we need half of lc, the length required to couple 100% of the power.

[17]:
# Values read off the 200 nm gap row of `df` (lc presumably in μm).
dn_gap200 = 0.026
lc_gap200 = 29.4
# Half of the full-transfer length gives (approximately) 3 dB coupling.
coupler_sc = gf.partial(
    gs.models.coupler, dn=dn_gap200, length=lc_gap200 / 2, coupling0=0
)
gs.plot_model(coupler_sc)
../../../_images/notebooks_plugins_sax_01_sax_25_0.png

FDTD Sparameters model#

You can also fit a model from Sparameter FDTD simulation data from tidy3d, Lumerical or MEEP.

[18]:
# Locate the stored Lumerical S-parameters for the default mmi1x2 component and
# build a SAX model from that .npz data.
filepath = get_sparameters_path_lumerical(gf.c.mmi1x2)
mmi1x2 = gs.read.model_from_npz(filepath=filepath)
gs.plot_model(mmi1x2)
../../../_images/notebooks_plugins_sax_01_sax_27_0.png
[19]:
# Plot only the o2/o3 output ports (ports2 presumably selects the destination ports — confirm).
gs.plot_model(mmi1x2, ports2=("o2", "o3"))
../../../_images/notebooks_plugins_sax_01_sax_28_0.png

Model fit#

You can fit a sax model to Sparameter FDTD simulation data.

[20]:
coupler_csv = gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv"
coupler_fdtd = gs.read.model_from_csv(
    filepath=coupler_csv,
    xkey="wavelength_nm",  # CSV column holding the wavelength axis
    prefix="S",
    xunits=1e-3,  # presumably rescales nm -> μm
)
[21]:
# Plot the coupler model built from the FDTD CSV data above.
gs.plot_model(coupler_fdtd)
../../../_images/notebooks_plugins_sax_01_sax_31_0.png
[22]:
# Lumerical simulation of a different coupler, read with the CSV reader's
# default keyword arguments.
lumerical_csv = gf.config.sparameters_path / "coupler_057254c0_00cc8908.csv"
coupler_fdtd = gs.read.model_from_csv(filepath=lumerical_csv)
gs.plot_model(coupler_fdtd)
../../../_images/notebooks_plugins_sax_01_sax_32_0.png

Let's fit the coupler spectrum with a scikit-learn linear regression on polynomial features of the wavelength.

[23]:
# 500 samples, evenly spaced in frequency between c / 1 μm and c / 2 μm.
f = jnp.linspace(constants.c / 1.0e-6, constants.c / 2.0e-6, 500) * 1e-12  # THz
wl = constants.c / (f * 1e12) * 1e6  # um

filepath = gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv"
coupler_fdtd = gs.read.model_from_csv(
    filepath, xkey="wavelength_nm", prefix="S", xunits=1e-3
)
sd = coupler_fdtd(wl=wl)

k, t = sd["o1", "o3"], sd["o1", "o4"]  # cross and through transmission
s, a = t + k, t - k  # symmetric and antisymmetric combinations
a = t - k

Let's fit the symmetric (t+k) and antisymmetric (t-k) transmission.

Symmetric#

[24]:
# |t + k| versus wavelength. `wl` is a wavelength in μm (not a frequency), and
# the curve is labelled so that plt.legend() has something to show.
plt.plot(wl, jnp.abs(s), label="|t + k| (fdtd)")
plt.grid(True)
plt.xlabel("λ [μm]")
plt.ylabel("Transmission")
plt.title("symmetric (transmission + coupling)")
plt.legend()
plt.show()
[17:56:51] WARNING  No artists with labels found to put in legend.  Note that artists whose label    legend.py:1330
                    start with an underscore are ignored when legend() is called with no argument.                 
../../../_images/notebooks_plugins_sax_01_sax_36_1.png
[25]:
# |t - k| versus wavelength. `wl` is a wavelength in μm (not a frequency), and
# the curve is labelled so that plt.legend() has something to show.
plt.plot(wl, jnp.abs(a), label="|t - k| (fdtd)")
plt.grid(True)
plt.xlabel("λ [μm]")
plt.ylabel("Transmission")
plt.title("anti-symmetric (transmission - coupling)")
plt.legend()
plt.show()
           WARNING  No artists with labels found to put in legend.  Note that artists whose label    legend.py:1330
                    start with an underscore are ignored when legend() is called with no argument.                 
../../../_images/notebooks_plugins_sax_01_sax_37_1.png
[26]:
r = LinearRegression()


def fX(x, _order=8):
    """Polynomial feature matrix used by all the spectrum fits.

    Args:
        x: 1D array of wavelengths.
        _order: number of powers generated, x**0 ... x**(_order - 1).

    Returns:
        Array of shape (len(x), _order) whose column j is x**j.
    """
    # Artificially create more 'features' (1, wl, wl**2, ...) so that a linear
    # regression performs a polynomial fit.
    return x[:, None] ** (jnp.arange(_order)[None, :])


X = fX(wl)
r.fit(X, jnp.abs(s))
asm, bsm = r.coef_, r.intercept_


def fsm(x):
    """Fitted magnitude of the symmetric transmission |t + k| at wavelengths x."""
    return fX(x) @ asm + bsm


plt.plot(wl, jnp.abs(s), label="|t + k| (fdtd)")
plt.plot(wl, fsm(wl), label="fit")
plt.grid(True)
plt.xlabel("λ [μm]")  # `wl` is a wavelength, not a frequency
plt.ylabel("Transmission")
plt.legend()
plt.show()
plt.show()
[17:56:52] WARNING  No artists with labels found to put in legend.  Note that artists whose label    legend.py:1330
                    start with an underscore are ignored when legend() is called with no argument.                 
../../../_images/notebooks_plugins_sax_01_sax_38_1.png
[27]:
r = LinearRegression()
r.fit(X, jnp.unwrap(jnp.angle(s)))
asp, bsp = r.coef_, r.intercept_


def fsp(x):
    """Fitted (unwrapped) phase of the symmetric transmission t + k, in rad."""
    return fX(x) @ asp + bsp


plt.plot(wl, jnp.unwrap(jnp.angle(s)), label="∠(t + k) (fdtd)")
plt.plot(wl, fsp(wl), label="fit")
plt.grid(True)
plt.xlabel("λ [μm]")  # `wl` is a wavelength, not a frequency
plt.ylabel("Angle [rad]")
plt.legend()
plt.show()
           WARNING  No artists with labels found to put in legend.  Note that artists whose label    legend.py:1330
                    start with an underscore are ignored when legend() is called with no argument.                 
../../../_images/notebooks_plugins_sax_01_sax_39_1.png
[28]:
def fs(x):
    """Complex symmetric transmission t + k rebuilt from the magnitude and phase fits."""
    magnitude = fsm(x)
    phase = fsp(x)
    return magnitude * jnp.exp(1j * phase)

Now let's fit the antisymmetric (t-k) transmission in the same way.

Anti-Symmetric#

[29]:
r = LinearRegression()
r.fit(X, jnp.abs(a))
aam, bam = r.coef_, r.intercept_


def fam(x):
    """Fitted magnitude of the antisymmetric transmission |t - k| at wavelengths x."""
    return fX(x) @ aam + bam


plt.plot(wl, jnp.abs(a), label="|t - k| (fdtd)")
plt.plot(wl, fam(wl), label="fit")
plt.grid(True)
plt.xlabel("λ [μm]")  # `wl` is a wavelength, not a frequency
plt.ylabel("Transmission")
plt.legend()
plt.show()
           WARNING  No artists with labels found to put in legend.  Note that artists whose label    legend.py:1330
                    start with an underscore are ignored when legend() is called with no argument.                 
../../../_images/notebooks_plugins_sax_01_sax_42_1.png
[30]:
r = LinearRegression()
r.fit(X, jnp.unwrap(jnp.angle(a)))
aap, bap = r.coef_, r.intercept_


def fap(x):
    """Fitted (unwrapped) phase of the antisymmetric transmission t - k, in rad."""
    return fX(x) @ aap + bap


plt.plot(wl, jnp.unwrap(jnp.angle(a)), label="∠(t - k) (fdtd)")
plt.plot(wl, fap(wl), label="fit")
plt.grid(True)
plt.xlabel("λ [μm]")  # `wl` is a wavelength, not a frequency
plt.ylabel("Angle [rad]")
plt.legend()
plt.show()
           WARNING  No artists with labels found to put in legend.  Note that artists whose label    legend.py:1330
                    start with an underscore are ignored when legend() is called with no argument.                 
../../../_images/notebooks_plugins_sax_01_sax_43_1.png
[31]:
def fa(x):
    """Complex antisymmetric transmission t - k rebuilt from the magnitude and phase fits."""
    magnitude = fam(x)
    phase = fap(x)
    return magnitude * jnp.exp(1j * phase)

Total#

[32]:
# Reconstruct the through transmission t = ((t + k) + (t - k)) / 2 from the fits.
t_ = 0.5 * (fs(wl) + fa(wl))

plt.plot(wl, jnp.abs(t), label="|t| (fdtd)")
plt.plot(wl, jnp.abs(t_), label="|t| (fit)")
plt.xlabel("λ [μm]")  # `wl` is a wavelength, not a frequency
plt.ylabel("Transmission")
plt.legend()
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_46_0.png
[33]:
# Reconstruct the cross coupling k = ((t + k) - (t - k)) / 2 from the fits.
k_ = 0.5 * (fs(wl) - fa(wl))

plt.plot(wl, jnp.abs(k), label="|k| (fdtd)")
plt.plot(wl, jnp.abs(k_), label="|k| (fit)")
plt.xlabel("λ [μm]")  # `wl` is a wavelength, not a frequency
plt.ylabel("Coupling")
plt.legend()
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_47_0.png
[34]:
@jax.jit
def coupler(wl=1.5):
    """Fitted coupler model built from the symmetric/antisymmetric fits.

    The fits need a 1D wavelength array (fX indexes x[:, None]), so `wl` is
    flattened first and the results are reshaped back to the input shape.

    Returns:
        Reciprocal SDict with through (o1-o4, o2-o4) and cross (o1-o3, o2-o3) paths.
    """
    wl = jnp.asarray(wl)
    original_shape = wl.shape
    flat_wl = wl.ravel()
    symmetric = fs(flat_wl)
    antisymmetric = fa(flat_wl)
    through = (0.5 * (symmetric + antisymmetric)).reshape(original_shape)
    cross = (0.5 * (symmetric - antisymmetric)).reshape(original_shape)
    return sax.reciprocal(
        {
            ("o1", "o4"): through,
            ("o1", "o3"): cross,
            ("o2", "o3"): cross,
            ("o2", "o4"): through,
        }
    )
[35]:
# Compare the fitted coupler model against the FDTD data on a common wavelength grid.
f = jnp.linspace(constants.c / 1.0e-6, constants.c / 2.0e-6, 500) * 1e-12  # THz
wl = constants.c / (f * 1e12) * 1e6  # um

filepath = gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv"
coupler_fdtd = gs.read.model_from_csv(
    filepath, xkey="wavelength_nm", prefix="S", xunits=1e-3
)
sd = coupler_fdtd(wl=wl)
sd_ = coupler(wl=wl)

# Through (T) and cross (K) power, and the cross-vs-through phase difference.
T = jnp.abs(sd["o1", "o4"]) ** 2
K = jnp.abs(sd["o1", "o3"]) ** 2
T_ = jnp.abs(sd_["o1", "o4"]) ** 2
K_ = jnp.abs(sd_["o1", "o3"]) ** 2
dP = jnp.unwrap(jnp.angle(sd["o1", "o3"]) - jnp.angle(sd["o1", "o4"]))
dP_ = jnp.unwrap(jnp.angle(sd_["o1", "o3"]) - jnp.angle(sd_["o1", "o4"]))

fig, ax_power = plt.subplots(figsize=(12, 3))
ax_power.plot(wl, T, label="T (fdtd)", c="C0", ls=":", lw=6)
ax_power.plot(wl, T_, label="T (model)", c="C0")
ax_power.plot(wl, K, label="K (fdtd)", c="C1", ls=":", lw=6)
ax_power.plot(wl, K_, label="K (model)", c="C1")
ax_power.set_ylim(-0.05, 1.05)
ax_power.grid(True)
ax_power.set_xlabel("λ [μm]")  # `wl` is a wavelength, not a frequency
ax_power.set_ylabel("Transmission")

# The phase difference lives on a secondary y-axis with its own label.
ax_phase = ax_power.twinx()
ax_phase.plot(wl, dP, label="ΔΦ (fdtd)", color="C2", ls=":", lw=6)
ax_phase.plot(wl, dP_, label="ΔΦ (model)", color="C2")
ax_phase.set_ylabel("ΔΦ [rad]")

fig.legend(bbox_to_anchor=(1.08, 0.9))
fig.savefig("fdtd_vs_model.png", bbox_inches="tight")
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_49_0.png

SAX gdsfactory Compatibility#

From Layout to Circuit Model

If you define SAX S-parameter models for your components, you can directly simulate your circuits from gdsfactory.

[36]:
# MZI layout with delta_length=10; displaying `mzi` prints its ports and references.
mzi = gf.components.mzi(delta_length=10)
mzi
mzi_delta_length10: uid d21b26ba, ports ['o1', 'o2'], references ['bend_euler_1', 'bend_euler_2', 'bend_euler_3', 'bend_euler_4', 'bend_euler_5', 'bend_euler_6', 'straight_5', 'straight_6', 'straight_7', 'bend_euler_7', 'bend_euler_8', 'straight_8', 'straight_9', 'straight_10', 'sytl', 'syl', 'sxt', 'sxb', 'cp1', 'cp2'], 0 polygons
[37]:
# Draw the connectivity graph of the MZI netlist.
mzi.plot_netlist()
<networkx.classes.graph.Graph object at 0x7f2c2ecc11f0>
../../../_images/notebooks_plugins_sax_01_sax_52_1.png
[38]:
netlist = mzi.get_netlist()
connections = netlist["connections"]  # "instance,port" -> "instance,port"
pprint(connections)
{'bend_euler_1,o1': 'cp1,o3',
 'bend_euler_1,o2': 'syl,o1',
 'bend_euler_2,o1': 'syl,o2',
 'bend_euler_2,o2': 'sxb,o1',
 'bend_euler_3,o1': 'cp1,o2',
 'bend_euler_3,o2': 'sytl,o1',
 'bend_euler_4,o1': 'sxt,o1',
 'bend_euler_4,o2': 'sytl,o2',
 'bend_euler_5,o1': 'straight_5,o2',
 'bend_euler_5,o2': 'straight_6,o1',
 'bend_euler_6,o1': 'straight_6,o2',
 'bend_euler_6,o2': 'straight_7,o1',
 'bend_euler_7,o1': 'straight_8,o2',
 'bend_euler_7,o2': 'straight_9,o1',
 'bend_euler_8,o1': 'straight_9,o2',
 'bend_euler_8,o2': 'straight_10,o1',
 'cp2,o2': 'straight_7,o2',
 'cp2,o3': 'straight_10,o2',
 'straight_5,o1': 'sxt,o2',
 'straight_8,o1': 'sxb,o2'}

The netlist has three different components:

  1. straight

  2. mmi1x2

  3. bend_euler

You need models for each subcomponents to simulate the Component.

[39]:
def straight(wl=1.5, length=10.0, neff=2.4) -> sax.SDict:
    """Lossless, dispersionless straight waveguide.

    Args:
        wl: wavelength, in the same length unit as `length`.
        length: waveguide length.
        neff: effective index; it does not depend on `wl` in this model.

    Returns:
        Reciprocal SDict with o1<->o2 transmission exp(2j*pi*neff*length/wl).
    """
    return sax.reciprocal({("o1", "o2"): jnp.exp(2j * jnp.pi * neff * length / wl)})


def mmi1x2():
    """Perfect 1x2 splitter: o1 splits equally (in power) into o2 and o3."""
    half_power_amplitude = 0.5**0.5
    return sax.reciprocal(
        {("o1", "o2"): half_power_amplitude, ("o1", "o3"): half_power_amplitude}
    )


def bend_euler(wl=1.5, length=20.0):
    """Euler bend modelled as a slightly lossy straight waveguide.

    Let's assume a reduced transmission for the euler bend compared to a
    straight: every S-parameter of `straight(wl=wl, length=length)` is scaled
    by 0.99 (a 1% field-amplitude reduction).
    """
    return {k: 0.99 * v for k, v in straight(wl=wl, length=length).items()}


# Map each component name in the netlist to its SAX model.
models = dict(
    bend_euler=bend_euler,
    mmi1x2=mmi1x2,
    straight=straight,
)
[40]:
# Build a SAX circuit function from the netlist and models; the second returned value is discarded.
circuit, _ = sax.circuit(netlist=netlist, models=models)
[41]:
wl = np.linspace(1.5, 1.6)  # numpy default of 50 points, in μm
S = circuit(wl=wl)
through_power = jnp.abs(S["o1", "o2"]) ** 2
wl_nm = 1e3 * wl

plt.figure(figsize=(14, 4))
plt.title("MZI")
plt.plot(wl_nm, through_power)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.grid(True)
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_57_0.png
[42]:
mzi = gf.components.mzi(delta_length=20)  # Doubling delta_length halves the FSR
mzi
mzi_delta_length20: uid 1af4aeb2, ports ['o1', 'o2'], references ['bend_euler_1', 'bend_euler_2', 'bend_euler_3', 'bend_euler_4', 'bend_euler_5', 'bend_euler_6', 'straight_5', 'straight_6', 'straight_7', 'bend_euler_7', 'bend_euler_8', 'straight_8', 'straight_9', 'straight_10', 'sytl', 'syl', 'sxt', 'sxb', 'cp1', 'cp2'], 0 polygons
[43]:
# Re-simulate the longer MZI on a denser 256-point wavelength grid.
circuit, _ = sax.circuit(netlist=mzi.get_netlist(), models=models)

wl = np.linspace(1.5, 1.6, 256)
S = circuit(wl=wl)
through_power = jnp.abs(S["o1", "o2"]) ** 2
wl_nm = 1e3 * wl

plt.figure(figsize=(14, 4))
plt.title("MZI")
plt.plot(wl_nm, through_power)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.grid(True)
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_59_0.png

Layout aware Monte Carlo#

You can model the impact of manufacturing variations on the performance of photonic circuits, thanks to the fast SAX circuit simulator combined with layout information and wafer maps of waveguide width and layer thickness variations.

The width and height variations can be extracted from:

  • Ring resonators 2017

  • MZI interferometers 2019

Waveguide Model#

To improve the waveguide model you need to find the effective index of the waveguide in relation to its parameters (width and thickness) using an open source mode solver.

[44]:
# Sweep the mode solver over a (wavelength, width) grid: 10 wavelengths in
# [1.5, 1.6] μm by 5 widths in [0.4, 0.6] μm; both arrays have shape (10, 5).
wavelengths, widths = np.mgrid[1.5:1.6:10j, 0.4:0.6:5j]
neffs = np.zeros_like(wavelengths)

for i, (wl, w) in enumerate(zip(tqdm(wavelengths.ravel()), widths.ravel())):
    wg = gt.modes.Waveguide(
        wavelength=wl,
        wg_width=w,
        mode_number=1,
        wg_thickness=0.22,
        slab_thickness=0.0,
        ncore="si",
        nclad="sio2",
    )
    wg.compute_modes()
    # `.flat[i]` writes straight into `neffs` in the same row-major order as
    # ravel(), without depending on ravel() returning a view.
    neffs.flat[i] = wg.neffs[0].real
2023-02-20 17:57:01.610 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/734e2ac6.npz mode data from file cache.
2023-02-20 17:57:01.630 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/edd73dca.npz mode data from file cache.
2023-02-20 17:57:01.650 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/a6c07d9e.npz mode data from file cache.
2023-02-20 17:57:01.670 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/955c203e.npz mode data from file cache.
2023-02-20 17:57:01.690 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/2dadd7c5.npz mode data from file cache.
2023-02-20 17:57:01.714 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/e87e75a9.npz mode data from file cache.
2023-02-20 17:57:01.734 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/63e7ddd5.npz mode data from file cache.
2023-02-20 17:57:01.754 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/de114284.npz mode data from file cache.
2023-02-20 17:57:01.774 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/71d4b4ff.npz mode data from file cache.
2023-02-20 17:57:01.793 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/c1215583.npz mode data from file cache.
2023-02-20 17:57:01.813 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/cb8443f9.npz mode data from file cache.
2023-02-20 17:57:01.833 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/6a0da6a1.npz mode data from file cache.
2023-02-20 17:57:01.853 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/335a02e5.npz mode data from file cache.
2023-02-20 17:57:01.872 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/abcd4522.npz mode data from file cache.
2023-02-20 17:57:01.892 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/5d6f135d.npz mode data from file cache.
2023-02-20 17:57:01.912 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/8d0531c7.npz mode data from file cache.
2023-02-20 17:57:01.934 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/9273d450.npz mode data from file cache.
2023-02-20 17:57:01.953 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/37802538.npz mode data from file cache.
2023-02-20 17:57:01.973 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/e01e51b3.npz mode data from file cache.
2023-02-20 17:57:01.993 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/7d1a39f6.npz mode data from file cache.
2023-02-20 17:57:02.012 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/c9925efc.npz mode data from file cache.
2023-02-20 17:57:02.032 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/12af6142.npz mode data from file cache.
2023-02-20 17:57:02.053 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/64ce0967.npz mode data from file cache.
2023-02-20 17:57:02.073 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/490fefa7.npz mode data from file cache.
2023-02-20 17:57:02.092 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/c88acc24.npz mode data from file cache.
2023-02-20 17:57:02.112 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/00a50ba5.npz mode data from file cache.
2023-02-20 17:57:02.131 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/5ebba308.npz mode data from file cache.
2023-02-20 17:57:02.151 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/d0efbf50.npz mode data from file cache.
2023-02-20 17:57:02.173 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/ca7dd62b.npz mode data from file cache.
2023-02-20 17:57:02.193 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/ef4bb880.npz mode data from file cache.
2023-02-20 17:57:02.212 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/6edbe277.npz mode data from file cache.
2023-02-20 17:57:02.232 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/2aed78d7.npz mode data from file cache.
2023-02-20 17:57:02.252 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/7813c640.npz mode data from file cache.
2023-02-20 17:57:02.273 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/1d0a50a1.npz mode data from file cache.
2023-02-20 17:57:02.294 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/079bfbde.npz mode data from file cache.
2023-02-20 17:57:02.313 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/936aa720.npz mode data from file cache.
2023-02-20 17:57:02.333 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/655e4b64.npz mode data from file cache.
2023-02-20 17:57:02.353 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/52284c76.npz mode data from file cache.
2023-02-20 17:57:02.375 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/f0c6cfbd.npz mode data from file cache.
2023-02-20 17:57:02.395 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/a7c7088e.npz mode data from file cache.
2023-02-20 17:57:02.415 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/9d302853.npz mode data from file cache.
2023-02-20 17:57:02.439 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/14217c5e.npz mode data from file cache.
2023-02-20 17:57:02.460 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/b929eaf5.npz mode data from file cache.
2023-02-20 17:57:02.482 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/e5c9959d.npz mode data from file cache.
2023-02-20 17:57:02.502 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/84bb7088.npz mode data from file cache.
2023-02-20 17:57:02.522 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/5b191282.npz mode data from file cache.
2023-02-20 17:57:02.541 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/dfb1a7b7.npz mode data from file cache.
2023-02-20 17:57:02.561 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/21b94bc6.npz mode data from file cache.
2023-02-20 17:57:02.582 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/5b6ba947.npz mode data from file cache.
2023-02-20 17:57:02.602 | INFO     | gdsfactory.simulation.gtidy3d.modes:compute_modes:439 - load /home/runner/.gdsfactory/modes/4cd60169.npz mode data from file cache.
[45]:
# Raw mode-solver samples of neff over the (wavelength, width) grid.
fig, ax = plt.subplots()
mesh = ax.pcolormesh(wavelengths, widths, neffs)
ax.set_xlabel("λ [μm]")
ax.set_ylabel("width [μm]")
fig.colorbar(mesh, ax=ax)
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_63_0.png
[46]:
_grid = [jnp.sort(jnp.unique(wavelengths)), jnp.sort(jnp.unique(widths))]
_data = jnp.asarray(neffs)


@jax.jit
def _get_coordinate(arr1d: jnp.ndarray, value: jnp.ndarray):
    return jnp.interp(value, arr1d, jnp.arange(arr1d.shape[0]))


@jax.jit
def _get_coordinates(arrs1d: List[jnp.ndarray], values: jnp.ndarray):
    """Convert one row of ``values`` per grid axis into fractional indices.

    ``values[i]`` is mapped onto ``arrs1d[i]`` with ``_get_coordinate`` and the
    per-axis results are stacked along a new leading axis.
    """
    # A Python loop instead of jax.vmap: the grids may have different lengths.
    per_axis = [_get_coordinate(grid, row) for grid, row in zip(arrs1d, values)]
    return jnp.array(per_axis)


@jax.jit
def neff(wl=1.55, width=0.5):
    """Effective index interpolated from the module-level ``_data`` table.

    ``wl`` and ``width`` are broadcast against each other, converted to
    fractional indices on ``_grid`` and linearly interpolated (``order=1``).
    Out-of-range inputs are clamped to the grid edge (``jnp.interp`` clamping
    plus ``mode="nearest"``). ``_data`` is indexed as (wavelength, width).
    """
    wl_b, width_b = jnp.broadcast_arrays(jnp.asarray(wl), jnp.asarray(width))
    coords = _get_coordinates(_grid, jnp.stack((wl_b, width_b), 0))
    return jax.scipy.ndimage.map_coordinates(_data, coords, 1, mode="nearest")


neff(wl=[1.52, 1.58], width=[0.5, 0.55])
Array([2.4988024, 2.4613397], dtype=float32)
[47]:
wavelengths_ = np.linspace(wavelengths.min(), wavelengths.max(), 100)
widths_ = np.linspace(widths.min(), widths.max(), 100)
wavelengths_, widths_ = np.meshgrid(wavelengths_, widths_)
neffs_ = neff(wavelengths_, widths_)
plt.pcolormesh(wavelengths_, widths_, neffs_)
plt.xlabel("λ [μm]")
plt.ylabel("width [μm]")
plt.colorbar()
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_65_0.png
[48]:
def straight(wl=1.55, length=10.0, width=0.5):
    """Lossless straight waveguide whose phase uses the interpolated ``neff``.

    Transmission in both directions is ``exp(2j*pi*neff(wl, width)/wl*length)``.
    """
    n_eff = neff(wl=wl, width=width)
    transmission = jnp.exp(2j * np.pi * n_eff / wl * length)
    return sax.reciprocal({("o1", "o2"): transmission})


def mmi1x2():
    """Ideal, lossless 1x2 splitter: amplitude sqrt(1/2) to each output."""
    amplitude = 0.5**0.5
    return sax.reciprocal({("o1", out_port): amplitude for out_port in ("o2", "o3")})


def mmi2x2():
    """Ideal, lossless 2x2 coupler with a 50/50 power split.

    Paths o1->o3 and o2->o4 have a real amplitude sqrt(1/2); paths o1->o4 and
    o2->o3 carry an extra factor 1j (90 degree phase).
    """
    real_amp = 0.5**0.5
    quad_amp = 1j * 0.5**0.5
    return sax.reciprocal(
        {
            ("o1", "o3"): real_amp,
            ("o1", "o4"): quad_amp,
            ("o2", "o3"): quad_amp,
            ("o2", "o4"): real_amp,
        }
    )


def bend_euler(wl=1.5, length=20.0, width=0.5):
    """Euler bend modeled as a ``straight`` of equal length with 0.99 amplitude.

    Let's assume a reduced transmission for the euler bend compared to a
    straight: every S-parameter of the equivalent straight is scaled by 0.99.

    NOTE: the default ``wl`` (1.5) differs from ``straight``'s default (1.55);
    pass ``wl`` explicitly to the circuit to keep all components consistent.
    """
    return {k: 0.99 * v for k, v in straight(wl=wl, length=length, width=width).items()}


models = {
    "bend_euler": bend_euler,
    "mmi1x2": mmi1x2,
    "mmi2x2": mmi2x2,
    "straight": straight,
}

Even though the straight waveguide model is still lossless, we’re at least modeling its phase correctly.

[49]:
straight()
{
    ('o1', 'o2'): Array(0.8401692-0.5423244j, dtype=complex64),
    ('o2', 'o1'): Array(0.8401692-0.5423244j, dtype=complex64)
}
[50]:
circuit, _ = sax.circuit(mzi.get_netlist(), models=models)
circuit()
{
    ('o1', 'o1'): Array(0.+0.j, dtype=complex64),
    ('o2', 'o2'): Array(0.+0.j, dtype=complex64),
    ('o1', 'o2'): Array(0.792619+0.15199116j, dtype=complex64),
    ('o2', 'o1'): Array(0.792619+0.15199116j, dtype=complex64)
}
[51]:
wl = jnp.linspace(1.51, 1.59, 1000)
S = circuit(wl=wl)
# Transmitted power |S(o1, o2)|^2 across the band.
plt.plot(wl, abs(S["o1", "o2"]) ** 2)
plt.xlabel("λ [μm]")
plt.ylabel("T")
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_70_0.png

Circuit model with variability#

Let’s assume the waveguide width varies across the wafer with a certain correlation length. We can create a ‘wafermap’ of width variations by sampling random width fluctuations and low-pass filtering them with a cutoff spatial frequency equal to the inverse of the correlation length. There are probably better ways to do this, but this works for this tutorial.

[52]:
def create_wafermaps(placements, correlation_length=1.0, num_maps=1, mean=0.0, std=1.0):
    """Generate random, spatially correlated wafer maps covering ``placements``.

    White Gaussian noise is sampled on a regular grid that extends one bounding
    box width/height beyond the placements on every side. It is low-pass
    filtered in the Fourier domain with a circular cutoff at spatial frequency
    ``1 / correlation_length``. The squared magnitude of the filtered field is
    then standardised, so every map has exactly the requested ``mean`` and
    ``std``.

    Args:
        placements: mapping of instance name -> dict with at least ``"x"`` and
            ``"y"`` keys (e.g. ``component.get_netlist()["placements"]``).
        correlation_length: correlation length, in the same unit as the
            placement coordinates. The grid pitch is ``correlation_length / 200``.
        num_maps: number of independent maps to generate.
        mean: target mean of each map.
        std: target standard deviation of each map.

    Returns:
        ``(x, y, W)``: the 1D grid coordinates and the maps, with ``W`` of shape
        ``(num_maps, len(x), len(y))``.

    Raises:
        ValueError: if ``placements`` is empty.
    """
    if not placements:
        raise ValueError("placements must contain at least one instance")

    dx = dy = correlation_length / 200
    xs = [p["x"] for p in placements.values()]
    ys = [p["y"] for p in placements.values()]
    xmin, xmax, ymin, ymax = min(xs), max(xs), min(ys), max(ys)
    wx, wy = xmax - xmin, ymax - ymin
    xmin, xmax, ymin, ymax = xmin - wx, xmax + wx, ymin - wy, ymax + wy
    x, y = np.arange(xmin, xmax + dx, dx), np.arange(ymin, ymax + dy, dy)
    W0 = np.random.randn(num_maps, x.shape[0], y.shape[0])

    # Only the two spatial axes take part in the FFT centering shifts; the
    # leading axis indexes independent maps and must not be permuted.
    spatial = (-2, -1)
    # Use the known pitch (not x[1] - x[0]) so single-sample grids also work.
    fx = fftshift(fftfreq(x.shape[0], d=dx))
    fy = fftshift(fftfreq(y.shape[0], d=dy))
    fY, fX = np.meshgrid(fy, fx)  # both shaped (len(x), len(y))
    fW = fftshift(fft2(W0), axes=spatial)

    # NOTE(review): this compares a length with a sample count; kept as in the
    # original. Zeroing everything yields constant maps equal to `mean`.
    if correlation_length >= min(x.shape[0], y.shape[0]):
        fW = np.zeros_like(fW)
    else:
        fW = np.where(np.sqrt(fX**2 + fY**2)[None] > 1 / correlation_length, 0, fW)

    # Undo the centering shift *before* inverting (ifftshift, then ifft2).
    W = np.abs(ifft2(np.fft.ifftshift(fW, axes=spatial))) ** 2

    # Per-map statistics over both spatial axes at once. Chaining
    # `.std(1).std(2)` would give the spread of the column stds instead.
    mean_ = W.mean(axis=spatial, keepdims=True)
    std_ = W.std(axis=spatial, keepdims=True)
    # A constant map has zero spread; divide by 1 there instead of by 0.
    std_ = np.where(std_ == 0, 1.0, std_)

    W = (W - mean_) / std_
    W = W * std + mean
    return x, y, W
[53]:
placements = mzi.get_netlist()["placements"]
xm, ym, wmaps = create_wafermaps(
    placements, correlation_length=100, mean=0.5, std=0.002, num_maps=100
)

for i, wmap in enumerate(wmaps):
    plt.imshow(wmap, cmap="RdBu")
    plt.show()
    if i == 2:
        break
../../../_images/notebooks_plugins_sax_01_sax_73_0.png
../../../_images/notebooks_plugins_sax_01_sax_73_1.png
../../../_images/notebooks_plugins_sax_01_sax_73_2.png
[54]:
def widths(xw, yw, wmaps, x, y):
    """Sample every wafer map in ``wmaps`` at the position(s) ``(x, y)``.

    ``(x, y)`` is converted to fractional indices on the ``xw``/``yw`` grid and
    each map is linearly interpolated there (edge values outside the grid).
    The result has one leading entry per map.

    NOTE: defining this function rebinds the module-level name ``widths``,
    which earlier cells used for the array of simulated waveguide widths.
    Cells that read that array must run before this one.
    """
    grid = [xw, yw]
    xy = jnp.stack(jnp.broadcast_arrays(jnp.asarray(x), jnp.asarray(y)), 0)
    coords = _get_coordinates(grid, xy)

    def _sample(wmap):
        return jax.scipy.ndimage.map_coordinates(
            wmap, coordinates=coords, order=1, mode="nearest"
        )

    return jax.vmap(_sample)(wmaps)

Let’s now sample the MZI width variation on the wafer map (let’s assume a single width variation per point):

Simple MZI#

[55]:
@gf.cell
def simple_mzi():
    """Build an MZI from discrete mmi, bend and straight cells.

    The input ``mmi1x2`` port ``o2`` feeds the top arm: bend, length-10
    straight, two mirrored bends, length-10 straight, bend. Port ``o3`` feeds
    the bottom arm of four bends, which ends on the ``mmi2x2`` input ``o1``.

    Ports: ``o1`` is the splitter input; ``o2``/``o3`` are the combiner's
    ``o3``/``o4`` outputs.
    """
    # NOTE(review): this publishes the first top-arm bend reference as a
    # module-level global, and nothing in this section reads it. Because of
    # @gf.cell it is presumably only assigned when the cell is actually built,
    # not when it is returned from a cache. Consider removing it.
    global bend_top1_
    c = gf.Component()

    # instances
    mmi_in = gf.components.mmi1x2()
    mmi_out = gf.components.mmi2x2()
    bend = gf.components.bend_euler()
    half_delay_straight = gf.components.straight(length=10.0)

    # references (sax convention: vars ending in underscore are references)
    # NOTE: netlist instance names (e.g. "bend_euler_1") appear to follow the
    # order in which references are added here, and the width_params keys
    # depend on them. Reordering these lines would rename instances.
    mmi_in_ = c << mmi_in
    mmi_out_ = c << mmi_out
    straight_top1_ = c << half_delay_straight
    straight_top2_ = c << half_delay_straight
    bend_top1_ = c << bend
    bend_top2_ = (c << bend).mirror()
    bend_top3_ = (c << bend).mirror()
    bend_top4_ = c << bend
    bend_btm1_ = (c << bend).mirror()
    bend_btm2_ = c << bend
    bend_btm3_ = c << bend
    bend_btm4_ = (c << bend).mirror()

    # connections
    # Top arm, chained from the splitter's o2 output.
    bend_top1_.connect("o1", mmi_in_.ports["o2"])
    straight_top1_.connect("o1", bend_top1_.ports["o2"])
    bend_top2_.connect("o1", straight_top1_.ports["o2"])
    bend_top3_.connect("o1", bend_top2_.ports["o2"])
    straight_top2_.connect("o1", bend_top3_.ports["o2"])
    bend_top4_.connect("o1", straight_top2_.ports["o2"])

    # Bottom arm, chained from the splitter's o3 output.
    bend_btm1_.connect("o1", mmi_in_.ports["o3"])
    bend_btm2_.connect("o1", bend_btm1_.ports["o2"])
    bend_btm3_.connect("o1", bend_btm2_.ports["o2"])
    bend_btm4_.connect("o1", bend_btm3_.ports["o2"])

    # The combiner is placed at the end of the bottom arm.
    # NOTE(review): the top arm's last bend is not explicitly connected to
    # mmi_out_. Presumably its o2 port lands on mmi_out_'s o2 by geometry and
    # get_netlist() detects that. Verify.
    mmi_out_.connect("o1", bend_btm4_.ports["o2"])

    # ports
    c.add_port(
        "o1",
        port=mmi_in_.ports["o1"],
    )
    c.add_port("o2", port=mmi_out_.ports["o3"])
    c.add_port("o3", port=mmi_out_.ports["o4"])
    return c


mzi = simple_mzi()
mzi
simple_mzi: uid 42bf20ee, ports ['o1', 'o2', 'o3'], references ['mmi1x2_1', 'mmi2x2_1', 'straight_1', 'straight_2', 'bend_euler_1', 'bend_euler_2', 'bend_euler_3', 'bend_euler_4', 'bend_euler_5', 'bend_euler_6', 'bend_euler_7', 'bend_euler_8'], 0 polygons
[56]:
circuit, _ = sax.circuit(mzi.get_netlist(), models=models)
[57]:
mzi_params = sax.get_settings(circuit)
placements = mzi.get_netlist()["placements"]
# Sample the wafer maps at each instance's placement, for every instance whose
# model accepts a `width` setting (one width value per wafer map).
width_params = {
    k: {"width": widths(xm, ym, wmaps, v["x"], v["y"])}
    for k, v in placements.items()
    if "width" in mzi_params[k]
}

S0 = circuit(wl=wl)
# wl[:, None] broadcasts against the per-map widths, so S presumably has shape
# (len(wl), num_maps).
S = circuit(
    wl=wl[:, None],
    **width_params,
)
ps = plt.plot(wl * 1e3, abs(S["o1", "o2"]) ** 2, color="C0", lw=1, alpha=0.1)
nps = plt.plot(wl * 1e3, abs(S0["o1", "o2"]) ** 2, color="C1", lw=2, alpha=1)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.plot([1550, 1550], [-1, 2], color="black", ls=":")
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.figlegend([*ps[-1:], *nps], ["MC", "nominal"], bbox_to_anchor=(1.1, 0.9))
# Root-mean-square deviation of the Monte-Carlo transmission from nominal
# (the square root was missing, which made this a mean squared error).
rmse = jnp.sqrt(
    jnp.mean(
        jnp.abs(jnp.abs(S["o1", "o2"]) ** 2 - jnp.abs(S0["o1", "o2"][:, None]) ** 2)
        ** 2
    )
)
plt.title(f"{rmse=}")
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_78_0.png

Compact MZI#

Let’s see if we can improve variability (i.e. the RMSE w.r.t. nominal) by making the MZI more compact:

[58]:
@gf.cell
def compact_mzi():
    """Build a more compact MZI variant from discrete mmi, bend and straight cells.

    The input ``mmi1x2`` port ``o2`` feeds the top arm: bend, default-length
    straight, mirrored bend, length-6 straight, mirrored bend, default-length
    straight, bend. Port ``o3`` feeds the bottom arm: length-3 straight, four
    bends, length-3 straight. The bottom arm ends on the ``mmi2x2`` input ``o1``.

    Ports: ``o1`` is the splitter input; ``o2``/``o3`` are the combiner's
    ``o3``/``o4`` outputs.
    """
    c = gf.Component()

    # instances
    mmi_in = gf.components.mmi1x2()
    mmi_out = gf.components.mmi2x2()
    bend = gf.components.bend_euler()
    half_delay_straight = gf.components.straight()
    middle_straight = gf.components.straight(length=6.0)
    half_middle_straight = gf.components.straight(3.0)

    # references (sax convention: vars ending in underscore are references)
    # NOTE: netlist instance names (e.g. "straight_1") appear to follow the
    # order in which references are added here, and the width_params keys
    # depend on them. Reordering these lines would rename instances.
    mmi_in_ = c << mmi_in

    bend_top1_ = c << bend
    straight_top1_ = c << half_delay_straight
    bend_top2_ = (c << bend).mirror()
    straight_top2_ = c << middle_straight
    bend_top3_ = (c << bend).mirror()
    straight_top3_ = c << half_delay_straight
    bend_top4_ = c << bend

    straight_btm1_ = c << half_middle_straight
    bend_btm1_ = c << bend
    bend_btm2_ = (c << bend).mirror()
    bend_btm3_ = (c << bend).mirror()
    bend_btm4_ = c << bend
    straight_btm2_ = c << half_middle_straight

    mmi_out_ = c << mmi_out

    # connections
    # Top arm, chained from the splitter's o2 output.
    bend_top1_.connect("o1", mmi_in_.ports["o2"])
    straight_top1_.connect("o1", bend_top1_.ports["o2"])
    bend_top2_.connect("o1", straight_top1_.ports["o2"])
    straight_top2_.connect("o1", bend_top2_.ports["o2"])
    bend_top3_.connect("o1", straight_top2_.ports["o2"])
    straight_top3_.connect("o1", bend_top3_.ports["o2"])
    bend_top4_.connect("o1", straight_top3_.ports["o2"])

    # Bottom arm, chained from the splitter's o3 output.
    straight_btm1_.connect("o1", mmi_in_.ports["o3"])
    bend_btm1_.connect("o1", straight_btm1_.ports["o2"])
    bend_btm2_.connect("o1", bend_btm1_.ports["o2"])
    bend_btm3_.connect("o1", bend_btm2_.ports["o2"])
    bend_btm4_.connect("o1", bend_btm3_.ports["o2"])
    straight_btm2_.connect("o1", bend_btm4_.ports["o2"])

    # The combiner is placed at the end of the bottom arm.
    # NOTE(review): the top arm's last bend is not explicitly connected to
    # mmi_out_. Presumably it meets mmi_out_'s o2 by geometry and
    # get_netlist() detects that. Verify.
    mmi_out_.connect("o1", straight_btm2_.ports["o2"])

    # ports
    c.add_port(
        "o1",
        port=mmi_in_.ports["o1"],
    )
    c.add_port("o2", port=mmi_out_.ports["o3"])
    c.add_port("o3", port=mmi_out_.ports["o4"])
    return c
[59]:
compact_mzi1 = compact_mzi()
fig = compact_mzi1.plot()
placements = compact_mzi1.get_netlist()["placements"]
mzi3, _ = sax.circuit(compact_mzi1.get_netlist(), models=models)
../../../_images/notebooks_plugins_sax_01_sax_81_0.png
[60]:
mzi_params = sax.get_settings(mzi3)
placements = compact_mzi1.get_netlist()["placements"]
# Sample the wafer maps at each instance's placement, for every instance whose
# model accepts a `width` setting (one width value per wafer map).
width_params = {
    k: {"width": widths(xm, ym, wmaps, v["x"], v["y"])}
    for k, v in placements.items()
    if "width" in mzi_params[k]
}

S0 = mzi3(wl=wl)
# wl[:, None] broadcasts against the per-map widths, so S presumably has shape
# (len(wl), num_maps).
S = mzi3(
    wl=wl[:, None],
    **width_params,
)
ps = plt.plot(wl * 1e3, abs(S["o1", "o2"]) ** 2, color="C0", lw=1, alpha=0.1)
nps = plt.plot(wl * 1e3, abs(S0["o1", "o2"]) ** 2, color="C1", lw=2, alpha=1)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.plot([1550, 1550], [-1, 2], color="black", ls=":")
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.figlegend([*ps[-1:], *nps], ["MC", "nominal"], bbox_to_anchor=(1.1, 0.9))
# Root-mean-square deviation of the Monte-Carlo transmission from nominal
# (the square root was missing, which made this a mean squared error).
rmse = jnp.sqrt(
    jnp.mean(
        jnp.abs(jnp.abs(S["o1", "o2"]) ** 2 - jnp.abs(S0["o1", "o2"][:, None]) ** 2)
        ** 2
    )
)
plt.title(f"{rmse=}")
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_82_0.png

Phase shifter model#

You can create a phase shifter model that depends on the applied voltage. For that, you first need to figure out the phase shift for different voltages.

[61]:
delta_length = 10
mzi_component = gf.components.mzi_phase_shifter_top_heater_metal(
    delta_length=delta_length
)
mzi_component
mzi_713867c1: uid 2460af46, ports ['o1', 'o2', 'e1', 'e2'], references ['bend_euler_1', 'bend_euler_2', 'bend_euler_3', 'bend_euler_4', 'bend_euler_5', 'bend_euler_6', 'straight_4', 'straight_5', 'straight_6', 'bend_euler_7', 'bend_euler_8', 'straight_7', 'straight_8', 'straight_9', 'sytl', 'syl', 'sxt', 'sxb', 'cp1', 'cp2'], 0 polygons
[62]:
def straight(wl=1.5, length=10.0, neff=2.4) -> sax.SDict:
    """Lossless straight waveguide with a wavelength-independent ``neff``.

    Args:
        wl: wavelength, assumed to be in the same unit as ``length`` (μm here).
        length: waveguide length.
        neff: effective index, constant over wavelength.

    Returns:
        Reciprocal S-dict with transmission ``exp(2j*pi*neff*length/wl)``.
    """
    # The former unused `wl0` local (and its comment) was dropped: no centre
    # wavelength enters this model.
    return sax.reciprocal({("o1", "o2"): jnp.exp(2j * jnp.pi * neff * length / wl)})


def mmi1x2() -> sax.SDict:
    """Returns an ideal, lossless 1x2 splitter (amplitude sqrt(1/2) per output)."""
    half_power_amp = 0.5**0.5
    return sax.reciprocal({("o1", out_port): half_power_amp for out_port in ("o2", "o3")})


def bend_euler(wl=1.5, length=20.0) -> sax.SDict:
    """Returns bend S-parameters: those of an equal-length ``straight`` scaled by 0.99."""
    straight_sdict = straight(wl=wl, length=length)
    return {port_pair: 0.99 * value for port_pair, value in straight_sdict.items()}


def phase_shifter_heater(
    wl: float = 1.55,
    neff: float = 2.34,
    voltage: float = 0,
    length: float = 10,
    loss: float = 0.0,
) -> sax.SDict:
    """Returns a simple voltage-controlled phase shifter model.

    Args:
        wl: wavelength.
        neff: effective index of the waveguide.
        voltage: applied voltage; adds ``voltage * pi`` radians of phase.
        length: phase shifter length.
        loss: propagation loss in dB per unit length (amplitude factor
            ``10 ** (-loss * length / 20)``).

    Returns:
        Reciprocal S-dict with the single transmission ``o1 <-> o2``.
    """
    phase = 2 * jnp.pi * neff * length / wl + voltage * jnp.pi
    amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    return sax.reciprocal({("o1", "o2"): amplitude * jnp.exp(1j * phase)})


models = {
    "bend_euler": bend_euler,
    "mmi1x2": mmi1x2,
    "straight": straight,
    "straight_heater_metal_undercut": phase_shifter_heater,
}
[63]:
mzi_component = gf.components.mzi_phase_shifter_top_heater_metal(
    delta_length=delta_length
)
netlist = mzi_component.get_netlist()
mzi_circuit, _ = sax.circuit(netlist=netlist, models=models)
S = mzi_circuit(wl=1.55)
S
{
    ('o1', 'o1'): Array(0.+0.j, dtype=complex64),
    ('o2', 'o2'): Array(0.+0.j, dtype=complex64),
    ('o1', 'o2'): Array(-0.08122635+0.72426844j, dtype=complex64),
    ('o2', 'o1'): Array(-0.08122632+0.72426844j, dtype=complex64)
}
[64]:
wl = np.linspace(1.5, 1.6, 256)
S = mzi_circuit(wl=wl)

plt.figure(figsize=(14, 4))
plt.title("MZI")
plt.plot(1e3 * wl, jnp.abs(S["o1", "o2"]) ** 2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.grid(True)
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_87_0.png

Now you can tune the phase shift applied to one of the arms.

How do you find out what’s the name of the netlist component that you want to tune?

You can back-annotate the netlist and read the labels on the back-annotated netlist, or you can plot the netlist:

[65]:
mzi_component.plot_netlist()
<networkx.classes.graph.Graph object at 0x7f2bf2b119d0>
../../../_images/notebooks_plugins_sax_01_sax_89_1.png

As you can see, the top phase shifter instance name `sxt` is hard to read on the netlist plot. You can also reconstruct the component from the netlist and look at the labels in KLayout.

[66]:
mzi_yaml = mzi_component.get_netlist_yaml()
mzi_component2 = gf.read.from_yaml(mzi_yaml)
mzi_component2.plot(label_aliases=True)
2023-02-20 17:57:21.170 | INFO     | gdsfactory.technology.layer_views:__init__:785 - Importing LayerViews from YAML file: /home/runner/work/gdsfactory/gdsfactory/gdsfactory/generic_tech/layer_views.yaml.
/home/runner/work/gdsfactory/gdsfactory/gdsfactory/read/from_yaml.py:792: UserWarning: YAML defined: (bend_euler_6, bend_euler_1, bend_euler_8, straight_7, straight_4, bend_euler_3, bend_euler_5, bend_euler_7, cp2, bend_euler_4, bend_euler_2) with both connection and placement. Please use one or the other.
  warnings.warn(
../../../_images/notebooks_plugins_sax_01_sax_91_1.png

The best way to get a deterministic instance name is to name the reference in your PCell.

[67]:
# A previous `voltages = np.linspace(...)` assignment was immediately
# overwritten (dead code), so only the effective sweep values are kept.
voltages = [-0.5, 0, 0.5]

for voltage in voltages:
    # "sxt" is the top-arm phase shifter instance; only its voltage changes.
    S = mzi_circuit(
        wl=wl,
        sxt={"voltage": voltage},
    )
    plt.plot(wl * 1e3, abs(S["o1", "o2"]) ** 2, label=str(voltage))

# Axis setup is done once, after all curves are drawn.
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.title("MZI vs voltage")
plt.legend()
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_93_0.png

Optimization#

You can optimize an MZI to get T=0 at 1530nm. To do this, you need to define a loss function for the circuit at 1530nm. This function should take the parameters that you want to optimize as positional arguments:

[68]:
def straight(wl=1.5, length=10.0, neff=2.4) -> sax.SDict:
    """Lossless straight waveguide with a wavelength-independent ``neff``.

    Args:
        wl: wavelength, assumed to be in the same unit as ``length`` (μm here).
        length: waveguide length.
        neff: effective index, constant over wavelength.

    Returns:
        Reciprocal S-dict with transmission ``exp(2j*pi*neff*length/wl)``.
    """
    # The former unused `wl0` local (and its comment) was dropped: no centre
    # wavelength enters this model.
    return sax.reciprocal({("o1", "o2"): jnp.exp(2j * jnp.pi * neff * length / wl)})


def mmi1x2():
    """Ideal, lossless 1x2 splitter: amplitude sqrt(1/2) to each output."""
    amp = 0.5**0.5
    return sax.reciprocal({("o1", port): amp for port in ("o2", "o3")})


def bend_euler(wl=1.5, length=20.0):
    """Euler bend modeled as a ``straight`` of equal length with 0.99 amplitude.

    Let's assume a reduced transmission for the euler bend compared to a
    straight: every S-parameter of the equivalent straight is scaled by 0.99.
    """
    return {k: 0.99 * v for k, v in straight(wl=wl, length=length).items()}


models = {
    "bend_euler": bend_euler,
    "mmi1x2": mmi1x2,
    "straight": straight,
}
[69]:
delta_length = 30
mzi_component = gf.components.mzi(delta_length=delta_length)
mzi_circuit, _ = sax.circuit(netlist=mzi_component.get_netlist(), models=models)
S = mzi_circuit(wl=1.55)
S
{
    ('o1', 'o1'): Array(0.+0.j, dtype=complex64),
    ('o2', 'o2'): Array(0.+0.j, dtype=complex64),
    ('o1', 'o2'): Array(-0.11767103-0.08549857j, dtype=complex64),
    ('o2', 'o1'): Array(-0.11767103-0.08549854j, dtype=complex64)
}
[70]:
wl = np.linspace(1.5, 1.6, 256)
S = mzi_circuit(wl=wl)

plt.figure(figsize=(14, 4))
plt.title("MZI")
plt.plot(1e3 * wl, jnp.abs(S["o1", "o2"]) ** 2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.plot([1530, 1530], [0, 1])
plt.grid(True)
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_97_0.png

GDSFactory automatically gives references incremental names (e.g. `straight_9`) in both the GDS and the netlist, which makes them easy to address.

[71]:
netlist = mzi_component.get_netlist()
c = gf.read.from_yaml(netlist)
c
2023-02-20 17:57:22.085 | INFO     | gdsfactory.technology.layer_views:__init__:785 - Importing LayerViews from YAML file: /home/runner/work/gdsfactory/gdsfactory/gdsfactory/generic_tech/layer_views.yaml.
/home/runner/work/gdsfactory/gdsfactory/gdsfactory/read/from_yaml.py:792: UserWarning: YAML defined: (bend_euler_6, bend_euler_1, bend_euler_8, bend_euler_3, bend_euler_5, straight_8, bend_euler_7, straight_5, cp2, bend_euler_4, bend_euler_2) with both connection and placement. Please use one or the other.
  warnings.warn(
mzi_delta_length30_a853854b: uid 3e9fb629, ports ['o1', 'o2'], references ['bend_euler_1', 'bend_euler_2', 'bend_euler_3', 'bend_euler_4', 'bend_euler_5', 'bend_euler_6', 'bend_euler_7', 'bend_euler_8', 'cp1', 'cp2', 'straight_10', 'straight_5', 'straight_6', 'straight_7', 'straight_8', 'straight_9', 'sxb', 'sxt', 'syl', 'sytl'], 0 polygons

From this we see that we will need to change syl and straight_9.

[72]:
mzi_component = gf.components.mzi(
    delta_length=delta_length,
)
mzi_circuit, _ = sax.circuit(
    netlist=mzi_component.get_netlist(),
    models=models,
)


@jax.jit
def loss_fn(delta_length):
    """Mean transmitted power |S(o1, o2)|^2 of the MZI at wl=1.53.

    ``delta_length`` is split evenly between the ``syl`` and ``straight_9``
    instances, each of which is set to length ``delta_length / 2 + 2``.
    """
    arm_length = delta_length / 2 + 2
    S = mzi_circuit(
        wl=1.53,
        syl={"length": arm_length},
        straight_9={"length": arm_length},
    )
    transmitted_power = abs(S["o1", "o2"]) ** 2
    return transmitted_power.mean()
[73]:
%time loss_fn(20.0)
CPU times: user 1.31 s, sys: 7.72 ms, total: 1.32 s
Wall time: 1.3 s
Array(0.14018838, dtype=float32)

You can use this loss function to define a grad function which works on the parameters of the loss function:

[74]:
grad_fn = jax.jit(
    jax.grad(
        loss_fn,
        argnums=0,  # JAX gradient function for the first positional argument, jitted
    )
)

Next, you need to define a JAX optimizer, which on its own is nothing more than three more functions:

  1. an initialization function with which to initialize the optimizer state

  2. an update function which will update the optimizer state (and with it the model parameters).

  3. a function that returns the model parameters given the optimizer state.

[75]:
initial_delta_length = 30.0
init_fn, update_fn, params_fn = opt.adam(step_size=0.1)
state = init_fn(initial_delta_length)
[76]:
def step_fn(step, state):
    """Run one optimizer iteration.

    Returns the loss at the current parameters (before the update) and the
    updated optimizer state.
    """
    params = params_fn(state)
    current_loss = loss_fn(params)
    new_state = update_fn(step, grad_fn(params), state)
    return current_loss, new_state
[77]:
range_ = trange(100)
for step in range_:
    loss, state = step_fn(step, state)
    range_.set_postfix(loss=f"{loss:.6f}")
[78]:
delta_length = params_fn(state)
delta_length
Array(30.282084, dtype=float32)
[79]:
S = mzi_circuit(
    wl=wl,
    syl={"length": delta_length / 2 + 2},
    straight_9={"length": delta_length / 2 + 2},
)
plt.figure(figsize=(14, 4))
plt.plot(wl * 1e3, abs(S["o1", "o2"]) ** 2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.ylim(-0.05, 1.05)
plt.plot([1530, 1530], [0, 1])
plt.grid(True)
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_110_0.png

The minimum of the MZI is perfectly located at 1530nm.

Hierarchical circuits#

You can also simulate hierarchical circuits, such as a lattice of MZIs.

[80]:
@gf.cell
def mzis(delta_length=10):
    """Two identical ``gf.components.mzi`` cells cascaded in series.

    Port ``o1`` is the first MZI's input and ``o2`` is the second MZI's output.
    """
    c = gf.Component()
    first_mzi = c << gf.components.mzi(delta_length=delta_length)
    second_mzi = c << gf.components.mzi(delta_length=delta_length)
    second_mzi.connect("o1", first_mzi.ports["o2"])

    c.add_port("o1", port=first_mzi.ports["o1"])
    c.add_port("o2", port=second_mzi.ports["o2"])
    return c


def straight(wl=1.5, length=10.0, neff=2.4) -> sax.SDict:
    """Straight model: lossless, with transmission ``exp(2j*pi*neff*length/wl)``."""
    transmission = jnp.exp(2j * jnp.pi * neff * length / wl)
    return sax.reciprocal({("o1", "o2"): transmission})


def mmi1x2():
    """Ideal, lossless 1x2 splitter: amplitude sqrt(1/2) to each output."""
    split_amp = 0.5**0.5
    return sax.reciprocal({("o1", dst): split_amp for dst in ("o2", "o3")})


def bend_euler(wl=1.5, length=20.0):
    """Euler bend: an equal-length ``straight`` with every S-parameter scaled by 0.99."""
    reference = straight(wl=wl, length=length)
    return {ports: 0.99 * s for ports, s in reference.items()}


models = {
    "bend_euler": bend_euler,
    "mmi1x2": mmi1x2,
    "straight": straight,
}


c2 = mzis()
c2
mzis: uid dfe3a099, ports ['o1', 'o2'], references ['mzi_1', 'mzi_2'], 0 polygons
[81]:
c2.plot_netlist_flat()
<networkx.classes.graph.Graph object at 0x7f2c2ec48d30>
../../../_images/notebooks_plugins_sax_01_sax_114_1.png
[82]:
c1 = gf.components.mzi(delta_length=10)
c1
mzi_delta_length10: uid d21b26ba, ports ['o1', 'o2'], references ['bend_euler_1', 'bend_euler_2', 'bend_euler_3', 'bend_euler_4', 'bend_euler_5', 'bend_euler_6', 'straight_5', 'straight_6', 'straight_7', 'bend_euler_7', 'bend_euler_8', 'straight_8', 'straight_9', 'straight_10', 'sytl', 'syl', 'sxt', 'sxb', 'cp1', 'cp2'], 0 polygons
[83]:
c1.plot_netlist()
<networkx.classes.graph.Graph object at 0x7f2bf43c82b0>
../../../_images/notebooks_plugins_sax_01_sax_116_1.png
[84]:
wl = np.linspace(1.5, 1.6)
netlist1 = c1.get_netlist_recursive()
circuit1, _ = sax.circuit(netlist=netlist1, models=models)
S1 = circuit1(wl=wl)

netlist2 = c2.get_netlist_recursive()
circuit2, _ = sax.circuit(netlist=netlist2, models=models)
S2 = circuit2(wl=wl)

plt.figure(figsize=(14, 4))
plt.plot(1e3 * wl, jnp.abs(S1["o1", "o2"]) ** 2, label="1 MZI")
plt.plot(1e3 * wl, jnp.abs(S2["o1", "o2"]) ** 2, label="2 MZI")
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.grid(True)
plt.legend()
plt.show()
../../../_images/notebooks_plugins_sax_01_sax_117_0.png
[ ]: